{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(\n",
      "  din=2,\n",
      "  dout=2\n",
      ")\n",
      "[[0.63114893 1.2928092 ]\n",
      " [0.63114893 1.2928092 ]]\n"
     ]
    }
   ],
   "source": [
    "from flax.experimental import nnx\n",
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "class Linear(nnx.Module):\n",
    "  \"\"\"Affine layer that maps its input through `x @ w + b`.\"\"\"\n",
    "\n",
    "  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n",
    "    \"\"\"Record the layer sizes and create its parameters.\n",
    "\n",
    "    Args:\n",
    "      din: size of the first axis of `w`.\n",
    "      dout: size of the second axis of `w` and the length of `b`.\n",
    "      rngs: provides the PRNG key (via `rngs.params()`) used to sample `w`.\n",
    "    \"\"\"\n",
    "    # Plain ints: kept on the module as static attributes.\n",
    "    self.din, self.dout = din, dout\n",
    "    # Variables: a uniformly sampled weight matrix and an all-zero bias.\n",
    "    w_key = rngs.params()\n",
    "    w_init = jax.random.uniform(w_key, (din, dout))\n",
    "    self.w = nnx.Param(w_init)\n",
    "    b_init = jnp.zeros((dout,))\n",
    "    self.b = nnx.Param(b_init)\n",
    "\n",
    "  def __call__(self, x):\n",
    "    \"\"\"Return `x @ w + b` for the input array `x`.\"\"\"\n",
    "    projected = x @ self.w\n",
    "    return projected + self.b\n",
    "\n",
    "\n",
    "# Build a layer with two inputs and two outputs from `nnx.Rngs(0)`.\n",
    "linear = Linear(din=2, dout=2, rngs=nnx.Rngs(0))\n",
    "\n",
    "# Apply the layer to a 2x2 array of ones.\n",
    "y = linear(jnp.ones(shape=(2, 2)))\n",
    "\n",
    "# Show the module's repr first, then the computed output.\n",
    "print(linear, y, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "State({\n",
      "  'b': Param(\n",
      "    value=Array([0., 0.], dtype=float32)\n",
      "  ),\n",
      "  'w': Param(\n",
      "    value=Array([[0.31696808, 0.55285215],\n",
      "           [0.31418085, 0.7399571 ]], dtype=float32)\n",
      "  )\n",
      "})\n",
      "ModuleDef(\n",
      "  type=Linear,\n",
      "  index=0,\n",
      "  static_fields=(('din', 2), ('dout', 2)),\n",
      "  variables=(('b', Param(\n",
      "      value=Empty\n",
      "    )), ('w', Param(\n",
      "      value=Empty\n",
      "    ))),\n",
      "  submodules=()\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# `split()` hands back two pieces: the variable values and the module definition.\n",
    "state, moduledef = linear.split()\n",
    "\n",
    "# Print both pieces, one after the other.\n",
    "print(state, moduledef, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(\n",
      "  din=2,\n",
      "  dout=2,\n",
      "  submodule=Linear(...)\n",
      ")\n",
      "[[0.63114893 1.2928092 ]\n",
      " [0.63114893 1.2928092 ]]\n"
     ]
    }
   ],
   "source": [
    "class Linear(nnx.Module):\n",
    "  \"\"\"Affine layer that also stores a reference to itself as `submodule`.\"\"\"\n",
    "\n",
    "  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n",
    "    \"\"\"Record the layer sizes, create the parameters and the self-reference.\n",
    "\n",
    "    Args:\n",
    "      din: size of the first axis of `w`.\n",
    "      dout: size of the second axis of `w` and the length of `b`.\n",
    "      rngs: provides the PRNG key (via `rngs.params()`) used to sample `w`.\n",
    "    \"\"\"\n",
    "    self.din, self.dout = din, dout\n",
    "    w_key = rngs.params()\n",
    "    w_init = jax.random.uniform(w_key, (din, dout))\n",
    "    self.w = nnx.Param(w_init)\n",
    "    b_init = jnp.zeros((dout,))\n",
    "    self.b = nnx.Param(b_init)\n",
    "    # Deliberately make the module point at itself.\n",
    "    self.submodule = self\n",
    "\n",
    "  def __call__(self, x):\n",
    "    \"\"\"Return `x @ w + b`, reading both parameters through `self.submodule`.\"\"\"\n",
    "    target = self.submodule\n",
    "    projected = x @ target.w\n",
    "    return projected + target.b\n",
    "\n",
    "\n",
    "# Instantiate the self-referencing layer from `nnx.Rngs(0)`.\n",
    "linear = Linear(din=2, dout=2, rngs=nnx.Rngs(0))\n",
    "\n",
    "# Apply the layer to a 2x2 array of ones.\n",
    "y = linear(jnp.ones(shape=(2, 2)))\n",
    "\n",
    "# Show the module's repr first, then the computed output.\n",
    "print(linear, y, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "State({\n",
      "  'b': Param(\n",
      "    value=Array([0., 0.], dtype=float32)\n",
      "  ),\n",
      "  'w': Param(\n",
      "    value=Array([[0.31696808, 0.55285215],\n",
      "           [0.31418085, 0.7399571 ]], dtype=float32)\n",
      "  )\n",
      "})\n",
      "ModuleDef(\n",
      "  type=Linear,\n",
      "  index=0,\n",
      "  static_fields=(('din', 2), ('dout', 2)),\n",
      "  variables=(('b', Param(\n",
      "      value=Empty\n",
      "    )), ('w', Param(\n",
      "      value=Empty\n",
      "    ))),\n",
      "  submodules=(\n",
      "    ('submodule', 0)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Split the self-referencing module into its variable values and its definition.\n",
    "state, moduledef = linear.split()\n",
    "\n",
    "# Print both pieces to see how the `submodule` attribute is recorded.\n",
    "print(state, moduledef, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Rebuild a module from the definition and the state produced by `split()`.\n",
    "linear2 = moduledef.merge(state)\n",
    "\n",
    "# Cell result: whether the rebuilt module's `submodule` is the rebuilt module itself.\n",
    "linear2 is linear2.submodule"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(\n",
      "  din=2,\n",
      "  dout=2\n",
      ")\n",
      "[[0.63114893 1.2928092 ]\n",
      " [0.63114893 1.2928092 ]]\n"
     ]
    }
   ],
   "source": [
    "class Linear(nnx.Module):\n",
    "  \"\"\"Affine layer that stores its latest output as an `nnx.Intermediate`.\"\"\"\n",
    "\n",
    "  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n",
    "    \"\"\"Record the layer sizes and create its parameters.\n",
    "\n",
    "    Args:\n",
    "      din: size of the first axis of `w`.\n",
    "      dout: size of the second axis of `w` and the length of `b`.\n",
    "      rngs: provides the PRNG key (via `rngs.params()`) used to sample `w`.\n",
    "    \"\"\"\n",
    "    # Plain ints: kept on the module as static attributes.\n",
    "    self.din, self.dout = din, dout\n",
    "    # Variables: a uniformly sampled weight matrix and an all-zero bias.\n",
    "    w_key = rngs.params()\n",
    "    w_init = jax.random.uniform(w_key, (din, dout))\n",
    "    self.w = nnx.Param(w_init)\n",
    "    b_init = jnp.zeros((dout,))\n",
    "    self.b = nnx.Param(b_init)\n",
    "\n",
    "  def __call__(self, x):\n",
    "    \"\"\"Compute `x @ w + b`, save it on `self.y`, and return it.\"\"\"\n",
    "    out = x @ self.w + self.b\n",
    "    # Keep the result on the module, wrapped as an Intermediate variable.\n",
    "    self.y = nnx.Intermediate(out)\n",
    "    return out\n",
    "\n",
    "\n",
    "# Instantiate the output-recording layer from `nnx.Rngs(0)`.\n",
    "linear = Linear(din=2, dout=2, rngs=nnx.Rngs(0))\n",
    "\n",
    "# Calling the layer also assigns `linear.y` inside `__call__`.\n",
    "y = linear(jnp.ones(shape=(2, 2)))\n",
    "\n",
    "# Show the module's repr first, then the computed output.\n",
    "print(linear, y, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "State({\n",
      "  'y': Intermediate(\n",
      "    value=Array([[0.63114893, 1.2928092 ],\n",
      "           [0.63114893, 1.2928092 ]], dtype=float32)\n",
      "  )\n",
      "})\n",
      "State({\n",
      "  'b': Param(\n",
      "    value=Array([0., 0.], dtype=float32)\n",
      "  ),\n",
      "  'w': Param(\n",
      "    value=Array([[0.31696808, 0.55285215],\n",
      "           [0.31418085, 0.7399571 ]], dtype=float32)\n",
      "  )\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "# Pull the Intermediate variables out of the module first ...\n",
    "intermediates = linear.pop(nnx.Intermediate)\n",
    "# ... then split what is left into variable values and module definition.\n",
    "state, moduledef = linear.split()\n",
    "\n",
    "# Print the popped intermediates, then the state from `split()`.\n",
    "print(intermediates, state, sep='\\n')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
